{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DQJauk8G0eXw"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "B_NBCOPt0ipM"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QAFKDBS6L-bv"
      },
      "source": [
        "# Unlocking Gemma's Power: Data-Parallel Inference on TPUs with JAX\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_1]data_parallel_inference_in_jax_tpu.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G6FSzFjJdkQ-"
      },
      "source": [
        "## **Intro**\n",
        "\n",
        "- This notebook demonstrates how to leverage [TPUs](https://www.kaggle.com/docs/tpu) and [JAX](https://jax.readthedocs.io/en/latest/) for **data-parallel inference** with the [Gemma](https://blog.google/technology/developers/gemma-open-models/) large language model.\n",
        "\n",
        "- This tutorial helps you tackle several movie review tasks simultaneously within a single batch of prompts. Imagine identifying key characters, summarizing plots, and classifying genres of a movie – all at blazing-fast speeds!\n",
        "\n",
        "- While this tutorial emphasizes movie reviews, the core concepts of data-parallel inference with Gemma extend to various real-world applications."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0d02FvvA0a67"
      },
      "source": [
        "### **JAX Data Parallel Inference**\n",
        "\n",
        "[JAX supports data parallelism](https://jax.readthedocs.io/en/latest/distributed_data_loading.html#data-parallelism) for efficient inference on multiple devices (TPUs, GPUs, etc.). In this approach, each device holds a replica of the model and processes a separate chunk of the input data (per-replica batch). This distribution reduces inference time for large datasets compared to running on a single device.\n",
        "\n",
        "**Key Points:**\n",
        "- Each device has a copy of the model.\n",
        "- Data is split into per-replica batches distributed across devices.\n",
        "- JAX automatically handles data distribution, so you don't need to worry about the order in which data lands on each device.\n",
        "This simplifies data loading: each device can independently receive its per-replica batch stream."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4M2TSq9ByVjE"
      },
      "source": [
        "### **Let's get started!**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pcAAkx8LKaHy"
      },
      "source": [
        "## Connect to a TPU:\n",
        "- To connect to a TPU v2, click the **Connect TPU** button in the top right-hand corner of the screen."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LXUWFVxV8gTI"
      },
      "source": [
        "You can now run the following code cell to see the TPU devices you have available:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1XWpUuCgo7kd"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n",
              " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n",
              " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n",
              " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n",
              " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n",
              " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n",
              " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n",
              " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]"
            ]
          },
          "execution_count": 1,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "import jax\n",
        "\n",
        "jax.local_devices()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dDjG76wZOvCb"
      },
      "source": [
        "Awesome! Our setup includes a TPU with 8 cores. This notebook will take advantage of this by splitting our workload across all cores (data parallelism). Each core will receive a fraction of the data (1/8th) and generate results simultaneously."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Sy2EwL0x8gTK"
      },
      "source": [
        "## Import libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LjQ4e57vmWCP"
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "from flax import jax_utils\n",
        "from flax.training.common_utils import shard\n",
        "from transformers import FlaxGemmaForCausalLM, AutoTokenizer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DRqIBQOmA7Nc"
      },
      "source": [
        "## How to Access Gemma:\n",
        "\n",
        "Before using Gemma for the first time, you need to request access to the model through [Hugging Face](https://huggingface.co/google/gemma-2b). This ensures you've accepted the model's terms-of-use.\n",
        "\n",
        "Since you'll be downloading the Gemma model weights from the Hugging Face Hub, you'll need a Hugging Face token to verify your acceptance.\n",
        "\n",
        "If you don't already have a Hugging Face account, you can register for one at [Hugging Face](https://huggingface.co/join). Once you have an account, follow these steps:\n",
        "\n",
        "1. Go to the [Hugging Face Gemma Model Card](https://huggingface.co/google/gemma-2b) and select Request Access.\n",
        "2. Complete the consent form and accept the terms and conditions.\n",
        "3. Go to [Hugging Face Hub account settings](https://huggingface.co/settings/tokens) and create a new access token.\n",
        "4. Copy your token.\n",
        "5. Then, in Colab, select **Secrets** (🔑) in the left pane, add a secret name (for example, `hugging_face_token_key`), and store your token as its value."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qE4YuCo-8gTL"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "try:\n",
        "    # Import inside the try block so the environment-variable fallback works outside Colab.\n",
        "    from google.colab import userdata\n",
        "    access_token = userdata.get('hugging_face_token_key')\n",
        "except ImportError:\n",
        "    access_token = os.environ['hugging_face_token_key']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rb2VBFGJQ1yx"
      },
      "source": [
        "## Load the Model\n",
        "You will use [Gemma 1.1 2B (IT)](https://huggingface.co/google/gemma-1.1-2b-it), an instruction-tuned model with 2 billion parameters and a lightweight footprint.\n",
        "\n",
        "The Gemma model can be loaded using the familiar [`from_pretrained`](https://huggingface.co/docs/transformers/v4.38.1/en/main_classes/model#transformers.FlaxPreTrainedModel.from_pretrained) method in Transformers. This method downloads the model weights from the Hugging Face Hub the first time it is called, and subsequently initializes the Gemma model using these weights.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KMHP_dDD8gTL"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n",
            "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
            "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
            "You will be able to reuse this secret in all of your notebooks.\n",
            "Please note that authentication is recommended but still optional to access public models or datasets.\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "59e82a5d3ce54e0ebe009f8cc985f11b",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config.json:   0%|          | 0.00/618 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "2374e7a7e41b4bf58d3b620607f2103d",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "flax_model.msgpack:   0%|          | 0.00/5.01G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Some of the weights of FlaxGemmaForCausalLM were initialized in bfloat16 precision from the model checkpoint at google/gemma-1.1-2b-it:\n",
            "[('model', 'embed_tokens', 'embedding'), ('model', 'layers', '0', 'input_layernorm', 'weight'), ('model', 'layers', '0', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '0', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '0', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '0', 'post_attention_layernorm', 'weight'), ('model', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '0', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '1', 'input_layernorm', 'weight'), ('model', 'layers', '1', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '1', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '1', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '1', 'post_attention_layernorm', 'weight'), ('model', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '1', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '10', 'input_layernorm', 'weight'), ('model', 'layers', '10', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '10', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '10', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '10', 'post_attention_layernorm', 'weight'), ('model', 'layers', '10', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '10', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '10', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '10', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '11', 'input_layernorm', 'weight'), ('model', 'layers', '11', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '11', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '11', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '11', 'post_attention_layernorm', 'weight'), ('model', 'layers', '11', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', 
'11', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '11', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '11', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '12', 'input_layernorm', 'weight'), ('model', 'layers', '12', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '12', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '12', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '12', 'post_attention_layernorm', 'weight'), ('model', 'layers', '12', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '12', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '12', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '12', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '13', 'input_layernorm', 'weight'), ('model', 'layers', '13', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '13', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '13', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '13', 'post_attention_layernorm', 'weight'), ('model', 'layers', '13', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '13', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '13', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '13', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '14', 'input_layernorm', 'weight'), ('model', 'layers', '14', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '14', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '14', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '14', 'post_attention_layernorm', 'weight'), ('model', 'layers', '14', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '14', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '14', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '14', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '15', 'input_layernorm', 'weight'), ('model', 'layers', '15', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '15', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '15', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', 
'15', 'post_attention_layernorm', 'weight'), ('model', 'layers', '15', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '15', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '15', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '15', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '16', 'input_layernorm', 'weight'), ('model', 'layers', '16', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '16', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '16', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '16', 'post_attention_layernorm', 'weight'), ('model', 'layers', '16', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '16', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '16', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '16', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '17', 'input_layernorm', 'weight'), ('model', 'layers', '17', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '17', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '17', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '17', 'post_attention_layernorm', 'weight'), ('model', 'layers', '17', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '17', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '17', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '17', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '2', 'input_layernorm', 'weight'), ('model', 'layers', '2', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '2', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '2', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '2', 'post_attention_layernorm', 'weight'), ('model', 'layers', '2', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '2', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '2', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '2', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '3', 'input_layernorm', 'weight'), ('model', 'layers', '3', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', 
'3', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '3', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '3', 'post_attention_layernorm', 'weight'), ('model', 'layers', '3', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '3', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '3', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '3', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '4', 'input_layernorm', 'weight'), ('model', 'layers', '4', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '4', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '4', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '4', 'post_attention_layernorm', 'weight'), ('model', 'layers', '4', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '4', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '4', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '4', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '5', 'input_layernorm', 'weight'), ('model', 'layers', '5', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '5', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '5', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '5', 'post_attention_layernorm', 'weight'), ('model', 'layers', '5', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '5', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '5', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '5', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '6', 'input_layernorm', 'weight'), ('model', 'layers', '6', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '6', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '6', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '6', 'post_attention_layernorm', 'weight'), ('model', 'layers', '6', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '6', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '6', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '6', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '7', 'input_layernorm', 
'weight'), ('model', 'layers', '7', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '7', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '7', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '7', 'post_attention_layernorm', 'weight'), ('model', 'layers', '7', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '7', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '7', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '7', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '8', 'input_layernorm', 'weight'), ('model', 'layers', '8', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '8', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '8', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '8', 'post_attention_layernorm', 'weight'), ('model', 'layers', '8', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '8', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '8', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '8', 'self_attn', 'v_proj', 'kernel'), ('model', 'layers', '9', 'input_layernorm', 'weight'), ('model', 'layers', '9', 'mlp', 'down_proj', 'kernel'), ('model', 'layers', '9', 'mlp', 'gate_proj', 'kernel'), ('model', 'layers', '9', 'mlp', 'up_proj', 'kernel'), ('model', 'layers', '9', 'post_attention_layernorm', 'weight'), ('model', 'layers', '9', 'self_attn', 'k_proj', 'kernel'), ('model', 'layers', '9', 'self_attn', 'o_proj', 'kernel'), ('model', 'layers', '9', 'self_attn', 'q_proj', 'kernel'), ('model', 'layers', '9', 'self_attn', 'v_proj', 'kernel'), ('model', 'norm', 'weight')]\n",
            "You should probably UPCAST the model weights to float32 if this was not intended. See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this.\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "6dabd5e1fe604820851293485f0d8dfa",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# You will use the latest Gemma 1.1 2B (IT), an update over the original instruction-tuned Gemma release.\n",
        "model_id = \"google/gemma-1.1-2b-it\"\n",
        "\n",
        "# Load the model with desired data type (bfloat16 for reduced memory usage)\n",
        "model, params = FlaxGemmaForCausalLM.from_pretrained(model_id, revision=\"flax\", _do_init=False, dtype=jnp.bfloat16, token=access_token)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8nSmjYj88gTM"
      },
      "source": [
        "The output above includes a warning that the model parameters were loaded in bfloat16 precision. This is expected, since you want to keep the parameters in bfloat16 for inference."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GQMjMGOLeHF_"
      },
      "source": [
        "The corresponding tokenizer can now be loaded using a similar API:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SOgY8-DroH4y"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "3332ea7d742b461ea844c26d8e42f672",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer_config.json:   0%|          | 0.00/34.2k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "9f42723a19c34b6ca395cc9391edcfe5",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "abdc6b6658bd4a15bc7a9622d26e440b",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "87a5df6d22154d259d6aac4d1df904c6",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "tokenizer = AutoTokenizer.from_pretrained(model_id, token=access_token)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BVz3ylFwepfU"
      },
      "source": [
        "## Define Inputs\n",
        "\n",
        "Next, you will define the text inputs. Since you have 8 TPU cores over which to perform data parallelism, the batch size must be a multiple of 8. This ensures that each TPU core receives the same number of samples (`bsz / 8`). You will change the input text later."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-XciqMHmsQqE"
      },
      "outputs": [],
      "source": [
        "input_text = 8 * [\"What year was the movie Titanic made?\"]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QuePjDn6sO3_"
      },
      "source": [
        "You can pre-process the input text into token IDs using the tokenizer. TPUs expect inputs of static shape, so you'll set the maximum prompt length to 64 and always pad the inputs to this sequence length:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X9a5KeFUoE6_"
      },
      "outputs": [],
      "source": [
        "max_input_length = 64\n",
        "\n",
        "inputs = tokenizer(\n",
        "    input_text,\n",
        "    padding=\"max_length\",\n",
        "    max_length=max_input_length,\n",
        "    return_attention_mask=True,\n",
        "    return_tensors=\"np\",\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eW3iFWNeUAxX"
      },
      "source": [
        "You now need to copy the model parameters to each TPU core. Each core will hold its own copy of the parameters, so that it can run a model generation in parallel with the others. Copying the parameters across devices is done with the [`replicate`](https://flax.readthedocs.io/en/latest/api_reference/flax.jax_utils.html#flax.jax_utils.replicate) method from Flax."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aMA1qL-5ofwq"
      },
      "outputs": [],
      "source": [
        "params = jax_utils.replicate(params)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xaygh3WvUh0G"
      },
      "source": [
        "Similarly, you need to split (or shard) the inputs across TPU cores. Sharding is done with the Flax helper function [`shard`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.common_utils.shard), which reshapes the leading batch dimension into one slice per local device:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yEUlaMepUi1M"
      },
      "outputs": [],
      "source": [
        "inputs = shard(inputs.data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Qfb-MXRPUZY5"
      },
      "source": [
        "## Inference\n",
        "\n",
        "You can now define the data-parallel method for inference. The Transformers [`generate`](https://huggingface.co/docs/transformers/v4.38.1/en/main_classes/text_generation#transformers.FlaxGenerationMixin.generate) method provides functionality for auto-regressive generation with batching, sampling, beam-search, etc. To reap the benefits of JAX, you'll compile the generate method end-to-end, so that the operations are fused into XLA-optimised kernels and executed efficiently on the hardware accelerator.\n",
        "\n",
        "To achieve this, you'll wrap the `generate` method with the [`jax.pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) transformation. The `jax.pmap` transformation compiles the `generate` method with XLA, and prepares a function that can be executed in parallel across TPU devices."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tOE2UjXqUZvi"
      },
      "outputs": [],
      "source": [
        "def generate(inputs, params, max_new_tokens):\n",
        "    generated_ids = model.generate(\n",
        "        inputs[\"input_ids\"],\n",
        "        attention_mask=inputs[\"attention_mask\"],\n",
        "        params=params,\n",
        "        max_new_tokens=max_new_tokens,\n",
        "        do_sample=True,\n",
        "    )\n",
        "    return generated_ids.sequences\n",
        "\n",
        "# \"inputs\" is the mapped axis name; max_new_tokens (argument 2) is passed as a static value.\n",
        "p_generate = jax.pmap(\n",
        "    generate, axis_name=\"inputs\", in_axes=(0, 0, None), out_axes=0, static_broadcasted_argnums=(2,)\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cepVL9tne2jg"
      },
      "source": [
        "Because `max_new_tokens` is a static argument, calling the generate function with a different value triggers a re-compilation. To avoid this, you'll define it once as a global variable here, and pass the same value to the generate function each time:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uVf9xbJBoYUG"
      },
      "outputs": [],
      "source": [
        "max_new_tokens = 128"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eTC7I3bmWXgL"
      },
      "source": [
        "You can now compile the parallel generate function. The first call triggers compilation, so it takes noticeably longer than later calls."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ws1XC2PGHBYu"
      },
      "outputs": [],
      "source": [
        "_ = p_generate(inputs, params, max_new_tokens)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XYxnfADQWqxs"
      },
      "source": [
        "Now that the function is compiled, you can run it again much faster using the optimised kernels:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "g1JbDhJQmh_w"
      },
      "outputs": [],
      "source": [
        "generated_ids = p_generate(inputs, params, max_new_tokens)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qDOUrhKMW2cT"
      },
      "source": [
        "The compiled generate function returns token IDs, which the tokenizer then decodes back into human-readable text:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "z-bUqMoJs6t1"
      },
      "outputs": [],
      "source": [
        "generated_ids = jax.device_get(generated_ids.reshape(-1, generated_ids.shape[-1]))\n",
        "pred_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Axw6Aqla0a7B"
      },
      "source": [
        "## Analyze Movie Reviews with Parallel Processing on TPUs\n",
        "\n",
        "This code demonstrates analyzing film critiques using multiple tasks processed concurrently on TPUs. The model performs the following tasks on each movie review:\n",
        "\n",
        "* **Identify Key Characters:** Find the two main characters in the film.\n",
        "* **Summarize Plot:** Briefly condense the story's key points.\n",
        "* **Predict Genre:** Classify the film genre based on the critique (e.g., comedy, drama).\n",
        "* **Recommend Similar Films:** Suggest two films with similar titles or themes.\n",
        "\n",
        "**Parallel Processing Power:**\n",
        "\n",
        "This approach leverages all available TPU cores to analyze multiple movies and tasks simultaneously. This significantly speeds up the analysis compared to processing each movie and task sequentially."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hJyX1-bw8gTP"
      },
      "outputs": [],
      "source": [
        "def create_movie_inputs(movie_titles, tasks):\n",
        "  \"\"\"Creates formatted input text for multiple movies and tasks.\n",
        "\n",
        "  Args:\n",
        "    movie_titles: List of movie titles with optional year information.\n",
        "    tasks: List of task descriptions with emphasized objectives.\n",
        "\n",
        "  Returns:\n",
        "    A list of formatted input text for the model.\n",
        "  \"\"\"\n",
        "  inputs = []\n",
        "  for title in movie_titles:\n",
        "    for task in tasks:\n",
        "      formatted_task = f\"\\n**Task:** {task}\\n=============\"\n",
        "      inputs.append(f\"\\n=============**Movie Title:**{title}{formatted_task}\")\n",
        "  return inputs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "j9Ygu1eZ8gTQ"
      },
      "outputs": [],
      "source": [
        "# Define movie titles with optional year information\n",
        "movie_titles = [\"Titanic (1997 film)\", \"Avatar (2009 film)\"]\n",
        "\n",
        "# Define tasks to be performed for each movie (one prompt per task)\n",
        "tasks = [\n",
        "    \"Main Characters (2 max): Who are the 2 key characters?\",\n",
        "    \"Plot Summary (1 sentence): Briefly summarize the story.\",\n",
        "    \"Genre: What is the most likely genre (e.g. science fiction, comedy, drama)?\",\n",
        "    \"Recommendation: Suggest 2 movies with similar titles or themes.\"\n",
        "]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vDgtekfC2_pY"
      },
      "outputs": [],
      "source": [
        "# Create formatted input text for the model, combining movie titles and tasks\n",
        "input_text = create_movie_inputs(movie_titles, tasks)\n",
        "\n",
        "# Tokenize the input text for the model using the provided tokenizer\n",
        "inputs = tokenizer(input_text, padding=\"max_length\", max_length=max_input_length, return_attention_mask=True, return_tensors=\"np\")\n",
        "\n",
        "# Shard the inputs for data parallelism across multiple TPUs\n",
        "inputs = shard(inputs.data)\n",
        "\n",
        "# Get the generated IDs back from the TPU device and reshape\n",
        "generated_ids = p_generate(inputs, params, max_new_tokens)\n",
        "generated_ids = jax.device_get(generated_ids.reshape(-1, generated_ids.shape[-1]))\n",
        "\n",
        "# Decode the generated IDs back into text using the tokenizer\n",
        "pred_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gvqg1Ens8gTQ"
      },
      "source": [
        "Now, the model has finished its predictions using all 8 TPU cores. Let's see what it generated for each task and each movie."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oC4hMvvs0a7H"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "=============**Movie Title:**Titanic (1997 film)\n",
            "**Task:Main Characters (2 max): Who are the 2 key characters?\n",
            "=============**\n",
            "\n",
            "In the movie Titanic (1997), there are two key characters that drive the plot forward: Jack Dawson and Rose DeWitt Bukater. Jack is a charming and wealthy young man from a poor background who dreams of escaping his humble life and achieving greatness. Rose is a strong-willed and independent young woman who represents the aspirations and resilience of women in the 19th century.\n",
            "\n",
            "=============**Movie Title:**Titanic (1997 film)\n",
            "**Task:Plot Summary (1 sentence): Briefly summarize the story.\n",
            "=============**\n",
            "\n",
            "In the year 1997, James Cameron crafts a tale of love, loss, and survival aboard the luxurious cruise ship RMS Titanic. With its opulent grand staircase, lavish cabins, and doomed romance, the film explores themes of class distinction, human capacity, and the fragility of life.\n",
            "\n",
            "=============**Movie Title:**Titanic (1997 film)\n",
            "**Task:Genre: What is the most likely genre (e.g. science fiction, comedy, drama)?\n",
            "=============**\n",
            "\n",
            "The provided text suggests that the genre of the movie Titanic (1997) is likely to be **drama**.\n",
            "\n",
            "=============**Movie Title:**Titanic (1997 film)\n",
            "**Task:Recommendation: Suggest 2 movies with similar titles or themes.\n",
            "=============**\n",
            "\n",
            "**Movie 1:** The Wolf of Wall Street (2013)\n",
            "**Movie 2:** The Greatest Showman (2017)\n"
          ]
        }
      ],
      "source": [
        "for i in range(4):  # First movie\n",
        "    print(pred_text[i])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "toeMIcG90a7H"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "=============**Movie Title:**Avatar (2009 film)\n",
            "**Task:Main Characters (2 max): Who are the 2 key characters?\n",
            "=============**\n",
            "\n",
            "The main characters of the 2009 film Avatar are Jake Sully and Neytiri.\n",
            "\n",
            "=============**Movie Title:**Avatar (2009 film)\n",
            "**Task:Plot Summary (1 sentence): Briefly summarize the story.\n",
            "=============**\n",
            "\n",
            "In the distant future, Jake Sully and his team of explorers travel to a remote planet called Pandora in search of the native species, the Na'vi. Guided by Neytiri, the Na'vi leader, they soon discover the planet's natural resources are being exploited. In the ensuing battle between colonization and preservation, Jake and Neytiri must make a choice that will determine the fate of Pandora.\n",
            "\n",
            "=============**Movie Title:**Avatar (2009 film)\n",
            "**Task:Genre: What is the most likely genre (e.g. science fiction, comedy, drama)?\n",
            "=============**\n",
            "\n",
            "The genre of the movie Avatar (2009) is most likely science fiction.\n",
            "\n",
            "The movie deals with themes of environmentalism, colonization, and the relationships between humans and nature. It also features advanced technology and special effects that are characteristic of science fiction.\n",
            "\n",
            "=============**Movie Title:**Avatar (2009 film)\n",
            "**Task:Recommendation: Suggest 2 movies with similar titles or themes.\n",
            "=============**\n",
            "\n",
            "**Movie 1:** The Curious Case of Benjamin Button (2008)\n",
            "**Movie 2:** Life of Pi (2012)\n"
          ]
        }
      ],
      "source": [
        "for i in range(4):  # Second movie\n",
        "    print(pred_text[i+4])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bOCUQGwocd_7"
      },
      "source": [
        "## Conclusion"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SBAzXojPcepV"
      },
      "source": [
        "- This notebook showcased efficient multi-task movie reviews using Gemma, TPUs and JAX for parallel inference tasks (character ID, plot, genre, and recommendations) all in a single prompt."
      ]
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "name": "[Gemma_1]data_parallel_inference_in_jax_tpu.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
